{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a880bbe6-78b7-4056-95bd-e8e3f236f01a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "try:\n",
    "    # Optional: load environment variables from a .env file (requires `python-dotenv`).\n",
    "    from dotenv import load_dotenv\n",
    "\n",
    "    _ = load_dotenv()\n",
    "except ImportError:\n",
    "    # python-dotenv is not installed; rely on the existing process environment.\n",
    "    pass\n",
    "\n",
    "# Prompt for the DashScope key only when DASHSCOPE_API_KEY is unset or empty,\n",
    "# so the secret is never written into the notebook itself.\n",
    "if not os.environ.get(\"DASHSCOPE_API_KEY\"):\n",
    "    os.environ[\"DASHSCOPE_API_KEY\"] = getpass.getpass(\"Enter API key for DashScope: \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "af2b6aa7-e0df-42ff-933a-b7c5f33ba891",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.chat_models.tongyi import ChatTongyi\n",
    "\n",
    "# Chat model shared by the graph nodes below.\n",
    "# `model` selects the Tongyi model; `name` would only label the runnable,\n",
    "# leaving the model choice to the library default.\n",
    "llm = ChatTongyi(\n",
    "    model=\"qwen-turbo\",\n",
    "    streaming=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9265912d-3c45-4321-b3a8-065f4396b09c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.embeddings.dashscope import DashScopeEmbeddings\n",
    "\n",
    "# Initialize the Qwen embedding model.\n",
    "embeddings = DashScopeEmbeddings(\n",
    "    model=\"text-embedding-v1\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7159eb33-4a79-47f5-bf4f-4cae066e7dda",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.vectorstores import InMemoryVectorStore\n",
    "\n",
    "# In-memory store that embeds documents with the model created above.\n",
    "vector_store = InMemoryVectorStore(embedding=embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4e448c09-7a5e-4721-ae68-005d01e7998e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "USER_AGENT environment variable not set, consider setting it to identify your requests.\n"
     ]
    }
   ],
   "source": [
    "import bs4\n",
    "from langchain import hub\n",
    "from langchain_community.document_loaders import WebBaseLoader\n",
    "from langchain_core.documents import Document\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "from typing_extensions import List, TypedDict\n",
    "\n",
    "# Load the web page, parsing only the elements whose class is \"view-wrap\".\n",
    "loader = WebBaseLoader(\n",
    "    web_paths=(\"https://gyxxh.tj.gov.cn/glllm/gabsycs/gxdtgh/202504/t20250408_6902492.html\",),\n",
    "    bs_kwargs={\"parse_only\": bs4.SoupStrainer(class_=\"view-wrap\")},\n",
    ")\n",
    "docs = loader.load()\n",
    "\n",
    "# Split the loaded documents into chunks (chunk_size=1000, chunk_overlap=200).\n",
    "text_splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=1000,\n",
    "    chunk_overlap=200,\n",
    ")\n",
    "all_splits = text_splitter.split_documents(docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ce698cb7-74c2-4367-9571-f6db8b9179cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add every chunk to the vector store; the return value is discarded so the\n",
    "# cell produces no output.\n",
    "_ = vector_store.add_documents(all_splits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c5ac197e-663f-44a1-8a4a-66b3a16ff027",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.tools import tool\n",
    "\n",
    "\n",
    "@tool(response_format=\"content_and_artifact\")\n",
    "def retrieve(query: str):\n",
    "    \"\"\"Retrieve information related to a query.\"\"\"\n",
    "    # NOTE(review): @tool presumably exposes the docstring above as the tool\n",
    "    # description seen by the model - edit it with care.\n",
    "    matches = vector_store.similarity_search(query, k=2)\n",
    "\n",
    "    # One \"Source/Content\" entry per hit; entries are separated by a blank line.\n",
    "    entries = []\n",
    "    for match in matches:\n",
    "        entries.append(f\"Source: {match.metadata}\\nContent: {match.page_content}\")\n",
    "\n",
    "    # Two values are returned: the serialized text and the matching documents.\n",
    "    return \"\\n\\n\".join(entries), matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "06f53d48-4864-4148-9bca-786370feb0da",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.messages import SystemMessage\n",
    "from langgraph.prebuilt import ToolNode\n",
    "from langgraph.graph import MessagesState\n",
    "\n",
    "# Step 1: Generate an AIMessage that may include a tool-call to be sent.\n",
    "def query_or_respond(state: MessagesState):\n",
    "    \"\"\"Let the tool-enabled model either request `retrieve` or reply directly.\"\"\"\n",
    "    tool_aware_llm = llm.bind_tools([retrieve])\n",
    "    reply = tool_aware_llm.invoke(state[\"messages\"])\n",
    "    # Return only the new message: MessagesState appends to state instead of overwriting.\n",
    "    return {\"messages\": [reply]}\n",
    "\n",
    "\n",
    "# Step 2: Execute the retrieval.\n",
    "tools = ToolNode([retrieve])\n",
    "\n",
    "\n",
    "# Step 3: Generate a response using the retrieved content.\n",
    "def generate(state: MessagesState):\n",
    "    \"\"\"Answer from the tool messages that end the current message history.\"\"\"\n",
    "    history = state[\"messages\"]\n",
    "\n",
    "    # Walk back from the end to find where the trailing run of tool messages starts.\n",
    "    first_tool_index = len(history)\n",
    "    while first_tool_index > 0 and history[first_tool_index - 1].type == \"tool\":\n",
    "        first_tool_index -= 1\n",
    "    tool_messages = history[first_tool_index:]\n",
    "\n",
    "    # Build the system prompt: fixed instructions followed by the retrieved text.\n",
    "    docs_content = \"\\n\\n\".join(tool_message.content for tool_message in tool_messages)\n",
    "    instructions = (\n",
    "        \"You are an assistant for question-answering tasks. \"\n",
    "        \"Use the following pieces of retrieved context to answer \"\n",
    "        \"the question. If you don't know the answer, say that you \"\n",
    "        \"don't know. Use three sentences maximum and keep the \"\n",
    "        \"answer concise.\"\n",
    "    )\n",
    "    system_message_content = f\"{instructions}\\n\\n{docs_content}\"\n",
    "\n",
    "    # Keep human/system turns and AI replies that carry no tool calls.\n",
    "    conversation_messages = []\n",
    "    for message in history:\n",
    "        is_plain_ai_reply = message.type == \"ai\" and not message.tool_calls\n",
    "        if message.type in (\"human\", \"system\") or is_plain_ai_reply:\n",
    "            conversation_messages.append(message)\n",
    "\n",
    "    prompt = [SystemMessage(system_message_content), *conversation_messages]\n",
    "    return {\"messages\": [llm.invoke(prompt)]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "9726047a-fdd5-4e22-a84f-e0a3538ea32a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.graph import END\n",
    "from langgraph.prebuilt import ToolNode, tools_condition\n",
    "from langgraph.graph import MessagesState, StateGraph\n",
    "\n",
    "graph_builder = StateGraph(MessagesState)\n",
    "\n",
    "# Register the three nodes under the names the edges below refer to.\n",
    "graph_builder.add_node(\"query_or_respond\", query_or_respond)\n",
    "graph_builder.add_node(\"tools\", tools)\n",
    "graph_builder.add_node(\"generate\", generate)\n",
    "\n",
    "# Every run starts at the model call; tools_condition then routes either to\n",
    "# the tool node or straight to END.\n",
    "graph_builder.set_entry_point(\"query_or_respond\")\n",
    "graph_builder.add_conditional_edges(\n",
    "    \"query_or_respond\",\n",
    "    tools_condition,\n",
    "    {END: END, \"tools\": \"tools\"},\n",
    ")\n",
    "\n",
    "# After the tool node runs, the answer is generated and the run ends.\n",
    "graph_builder.add_edge(\"tools\", \"generate\")\n",
    "graph_builder.add_edge(\"generate\", END)\n",
    "\n",
    "graph = graph_builder.compile()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "f119004c-4bc3-4a41-b92c-9bd54b77809a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from IPython.display import Image, display\n",
    "\n",
    "# display(Image(graph.get_graph().draw_mermaid_png()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "bf211f21-e75f-4d70-9dfe-137a91268e46",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这是|在|2|025年|4月8日|发布的内容。|"
     ]
    }
   ],
   "source": [
    "input_message = \"这是星期几发布的?\"\n",
    "\n",
    "stream = graph.stream(\n",
    "    {\"messages\": [{\"role\": \"user\", \"content\": input_message}]},\n",
    "    stream_mode=\"messages\",\n",
    ")\n",
    "\n",
    "# Each streamed item is a (message chunk, metadata) pair; print the text of\n",
    "# chunks coming from the two model-calling nodes, separated by \"|\".\n",
    "for chunk, chunk_metadata in stream:\n",
    "    if chunk_metadata[\"langgraph_node\"] not in (\"query_or_respond\", \"generate\"):\n",
    "        continue\n",
    "    text = chunk.text()\n",
    "    if text:\n",
    "        print(text, end=\"|\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b6bc7de-9ebb-43a6-90d8-241dc28e13b9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
